// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/socket/transport_client_socket_pool_test_util.h"

#include <stdint.h>
#include <string>
#include <utility>

#include "base/location.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "base/run_loop.h"
#include "base/single_thread_task_runner.h"
#include "base/threading/thread_task_runner_handle.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/load_timing_info.h"
#include "net/base/load_timing_info_test_util.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/ssl_client_socket.h"
#include "net/udp/datagram_client_socket.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace net {

namespace {

    // Converts the IP literal |ip| into an IPAddress. The conversion is wrapped
    // in CHECK(), so a literal that AssignFromIPLiteral() rejects is fatal rather
    // than being reported to the caller.
    IPAddress ParseIP(const std::string& ip)
    {
        IPAddress address;
        CHECK(address.AssignFromIPLiteral(ip));
        return address;
    }

    // A StreamSocket which connects synchronously and successfully.
    class MockConnectClientSocket : public StreamSocket {
    public:
        MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log)
            : connected_(false)
            , addrlist_(addrlist)
            , net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET))
        {
        }

        // StreamSocket implementation.
        int Connect(const CompletionCallback& callback) override
        {
            connected_ = true;
            return OK;
        }
        void Disconnect() override { connected_ = false; }
        bool IsConnected() const override { return connected_; }
        bool IsConnectedAndIdle() const override { return connected_; }

        int GetPeerAddress(IPEndPoint* address) const override
        {
            *address = addrlist_.front();
            return OK;
        }
        int GetLocalAddress(IPEndPoint* address) const override
        {
            if (!connected_)
                return ERR_SOCKET_NOT_CONNECTED;
            if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
                SetIPv4Address(address);
            else
                SetIPv6Address(address);
            return OK;
        }
        const BoundNetLog& NetLog() const override { return net_log_; }

        void SetSubresourceSpeculation() override { }
        void SetOmniboxSpeculation() override { }
        bool WasEverUsed() const override { return false; }
        void EnableTCPFastOpenIfSupported() override { }
        bool WasNpnNegotiated() const override { return false; }
        NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
        bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
        void GetConnectionAttempts(ConnectionAttempts* out) const override
        {
            out->clear();
        }
        void ClearConnectionAttempts() override { }
        void AddConnectionAttempts(const ConnectionAttempts& attempts) override { }
        int64_t GetTotalReceivedBytes() const override
        {
            NOTIMPLEMENTED();
            return 0;
        }

        // Socket implementation.
        int Read(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override
        {
            return ERR_FAILED;
        }
        int Write(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override
        {
            return ERR_FAILED;
        }
        int SetReceiveBufferSize(int32_t size) override { return OK; }
        int SetSendBufferSize(int32_t size) override { return OK; }

    private:
        bool connected_;
        const AddressList addrlist_;
        BoundNetLog net_log_;

        DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket);
    };

    class MockFailingClientSocket : public StreamSocket {
    public:
        MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log)
            : addrlist_(addrlist)
            , net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET))
        {
        }

        // StreamSocket implementation.
        int Connect(const CompletionCallback& callback) override
        {
            return ERR_CONNECTION_FAILED;
        }

        void Disconnect() override { }

        bool IsConnected() const override { return false; }
        bool IsConnectedAndIdle() const override { return false; }
        int GetPeerAddress(IPEndPoint* address) const override
        {
            return ERR_UNEXPECTED;
        }
        int GetLocalAddress(IPEndPoint* address) const override
        {
            return ERR_UNEXPECTED;
        }
        const BoundNetLog& NetLog() const override { return net_log_; }

        void SetSubresourceSpeculation() override { }
        void SetOmniboxSpeculation() override { }
        bool WasEverUsed() const override { return false; }
        void EnableTCPFastOpenIfSupported() override { }
        bool WasNpnNegotiated() const override { return false; }
        NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
        bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
        void GetConnectionAttempts(ConnectionAttempts* out) const override
        {
            out->clear();
            for (const auto& addr : addrlist_)
                out->push_back(ConnectionAttempt(addr, ERR_CONNECTION_FAILED));
        }
        void ClearConnectionAttempts() override { }
        void AddConnectionAttempts(const ConnectionAttempts& attempts) override { }
        int64_t GetTotalReceivedBytes() const override
        {
            NOTIMPLEMENTED();
            return 0;
        }

        // Socket implementation.
        int Read(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override
        {
            return ERR_FAILED;
        }

        int Write(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override
        {
            return ERR_FAILED;
        }
        int SetReceiveBufferSize(int32_t size) override { return OK; }
        int SetSendBufferSize(int32_t size) override { return OK; }

    private:
        const AddressList addrlist_;
        BoundNetLog net_log_;

        DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket);
    };

    class MockTriggerableClientSocket : public StreamSocket {
    public:
        // |should_connect| indicates whether the socket should successfully complete
        // or fail.
        MockTriggerableClientSocket(const AddressList& addrlist,
            bool should_connect,
            net::NetLog* net_log)
            : should_connect_(should_connect)
            , is_connected_(false)
            , addrlist_(addrlist)
            , net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET))
            , weak_factory_(this)
        {
        }

        // Call this method to get a closure which will trigger the connect callback
        // when called. The closure can be called even after the socket is deleted; it
        // will safely do nothing.
        base::Closure GetConnectCallback()
        {
            return base::Bind(&MockTriggerableClientSocket::DoCallback,
                weak_factory_.GetWeakPtr());
        }

        static std::unique_ptr<StreamSocket> MakeMockPendingClientSocket(
            const AddressList& addrlist,
            bool should_connect,
            net::NetLog* net_log)
        {
            std::unique_ptr<MockTriggerableClientSocket> socket(
                new MockTriggerableClientSocket(addrlist, should_connect, net_log));
            base::ThreadTaskRunnerHandle::Get()->PostTask(FROM_HERE,
                socket->GetConnectCallback());
            return std::move(socket);
        }

        static std::unique_ptr<StreamSocket> MakeMockDelayedClientSocket(
            const AddressList& addrlist,
            bool should_connect,
            const base::TimeDelta& delay,
            net::NetLog* net_log)
        {
            std::unique_ptr<MockTriggerableClientSocket> socket(
                new MockTriggerableClientSocket(addrlist, should_connect, net_log));
            base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
                FROM_HERE, socket->GetConnectCallback(), delay);
            return std::move(socket);
        }

        static std::unique_ptr<StreamSocket> MakeMockStalledClientSocket(
            const AddressList& addrlist,
            net::NetLog* net_log,
            bool failing)
        {
            std::unique_ptr<MockTriggerableClientSocket> socket(
                new MockTriggerableClientSocket(addrlist, true, net_log));
            if (failing) {
                DCHECK_LE(1u, addrlist.size());
                ConnectionAttempts attempts;
                attempts.push_back(ConnectionAttempt(addrlist[0], ERR_CONNECTION_FAILED));
                socket->AddConnectionAttempts(attempts);
            }
            return std::move(socket);
        }

        // StreamSocket implementation.
        int Connect(const CompletionCallback& callback) override
        {
            DCHECK(callback_.is_null());
            callback_ = callback;
            return ERR_IO_PENDING;
        }

        void Disconnect() override { }

        bool IsConnected() const override { return is_connected_; }
        bool IsConnectedAndIdle() const override { return is_connected_; }
        int GetPeerAddress(IPEndPoint* address) const override
        {
            *address = addrlist_.front();
            return OK;
        }
        int GetLocalAddress(IPEndPoint* address) const override
        {
            if (!is_connected_)
                return ERR_SOCKET_NOT_CONNECTED;
            if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
                SetIPv4Address(address);
            else
                SetIPv6Address(address);
            return OK;
        }
        const BoundNetLog& NetLog() const override { return net_log_; }

        void SetSubresourceSpeculation() override { }
        void SetOmniboxSpeculation() override { }
        bool WasEverUsed() const override { return false; }
        void EnableTCPFastOpenIfSupported() override { }
        bool WasNpnNegotiated() const override { return false; }
        NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
        bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
        void GetConnectionAttempts(ConnectionAttempts* out) const override
        {
            *out = connection_attempts_;
        }
        void ClearConnectionAttempts() override { connection_attempts_.clear(); }
        void AddConnectionAttempts(const ConnectionAttempts& attempts) override
        {
            connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(),
                attempts.end());
        }
        int64_t GetTotalReceivedBytes() const override
        {
            NOTIMPLEMENTED();
            return 0;
        }

        // Socket implementation.
        int Read(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override
        {
            return ERR_FAILED;
        }

        int Write(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override
        {
            return ERR_FAILED;
        }
        int SetReceiveBufferSize(int32_t size) override { return OK; }
        int SetSendBufferSize(int32_t size) override { return OK; }

    private:
        void DoCallback()
        {
            is_connected_ = should_connect_;
            callback_.Run(is_connected_ ? OK : ERR_CONNECTION_FAILED);
        }

        bool should_connect_;
        bool is_connected_;
        const AddressList addrlist_;
        BoundNetLog net_log_;
        CompletionCallback callback_;
        ConnectionAttempts connection_attempts_;

        base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_;

        DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket);
    };

} // namespace

// Checks the load timing information |handle| reports when it is queried as a
// reused socket: the query must succeed, |socket_reused| must be set, the
// socket log id must be valid, and the timing data must satisfy
// ExpectConnectTimingHasNoTimes() and ExpectLoadTimingHasOnlyConnectionTimes().
void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle)
{
    LoadTimingInfo load_timing_info;
    // Only pass true in as |is_reused|, as in general, HttpStream types should
    // have stricter concepts of reuse than socket pools.
    EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));

    EXPECT_TRUE(load_timing_info.socket_reused);
    EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);

    ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
    ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
}

// Checks the load timing information reported by a |handle| that is not marked
// as reused: the query with |is_reused| == false must succeed, |socket_reused|
// must be clear, the socket log id must be valid, and the connect timing must
// satisfy ExpectConnectTimingHasTimes() with CONNECT_TIMING_HAS_DNS_TIMES.
// Afterwards the same handle is also run through the reused-socket checks.
void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle)
{
    EXPECT_FALSE(handle.is_reused());

    LoadTimingInfo load_timing_info;
    EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));

    EXPECT_FALSE(load_timing_info.socket_reused);
    EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);

    ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
        CONNECT_TIMING_HAS_DNS_TIMES);
    ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);

    // Querying the same handle with |is_reused| == true must also yield
    // results that pass the reused-socket checks.
    TestLoadTimingInfoConnectedReused(handle);
}

// Overwrites |*address| with the fixed IPv4 endpoint 1.1.1.1, port 80.
void SetIPv4Address(IPEndPoint* address)
{
    const IPAddress ipv4 = ParseIP("1.1.1.1");
    *address = IPEndPoint(ipv4, 80);
}

// Overwrites |*address| with the fixed IPv6 endpoint 1:abcd::3:4:ff, port 80.
void SetIPv6Address(IPEndPoint* address)
{
    const IPAddress ipv6 = ParseIP("1:abcd::3:4:ff");
    *address = IPEndPoint(ipv6, 80);
}

// Builds a factory whose mock sockets log to |net_log|. Initially every
// created socket is of type MOCK_CLIENT_SOCKET, no per-call type list is
// installed, and the delay used for the "delayed" socket types is
// ClientSocketPool::kMaxConnectRetryIntervalMs milliseconds.
MockTransportClientSocketFactory::MockTransportClientSocketFactory(
    NetLog* net_log)
    : net_log_(net_log)
    , allocation_count_(0)
    , client_socket_type_(MOCK_CLIENT_SOCKET)
    , client_socket_types_(nullptr)
    , client_socket_index_(0)
    , client_socket_index_max_(0)
    , delay_(base::TimeDelta::FromMilliseconds(
          ClientSocketPool::kMaxConnectRetryIntervalMs))
{
}

// Nothing to release explicitly; members clean up after themselves.
MockTransportClientSocketFactory::~MockTransportClientSocketFactory() = default;

// Datagram sockets are not expected to be requested from this factory: the
// call is flagged with NOTREACHED() and an empty pointer is returned. All
// arguments are ignored.
std::unique_ptr<DatagramClientSocket>
MockTransportClientSocketFactory::CreateDatagramClientSocket(
    DatagramSocket::BindType bind_type,
    const RandIntCallback& rand_int_cb,
    NetLog* net_log,
    const NetLog::Source& source)
{
    NOTREACHED();
    return nullptr;
}

// Creates the next mock socket for |addresses| and bumps the allocation count.
//
// The socket type is taken from the list installed with
// set_client_socket_types() while entries remain there (consuming one entry
// per call); after that, or when no list is installed, the single configured
// |client_socket_type_| is used. The NetLog arguments passed in are ignored;
// sockets always log to the factory's own |net_log_|.
std::unique_ptr<StreamSocket>
MockTransportClientSocketFactory::CreateTransportClientSocket(
    const AddressList& addresses,
    std::unique_ptr<SocketPerformanceWatcher> /* socket_performance_watcher */,
    NetLog* /* net_log */,
    const NetLog::Source& /* source */)
{
    ++allocation_count_;

    const bool take_from_list = client_socket_types_
        && client_socket_index_ < client_socket_index_max_;
    const ClientSocketType type = take_from_list
        ? client_socket_types_[client_socket_index_++]
        : client_socket_type_;

    switch (type) {
    case MOCK_CLIENT_SOCKET:
        return std::unique_ptr<StreamSocket>(
            new MockConnectClientSocket(addresses, net_log_));

    case MOCK_FAILING_CLIENT_SOCKET:
        return std::unique_ptr<StreamSocket>(
            new MockFailingClientSocket(addresses, net_log_));

    // The paired types below differ only in whether the eventual connect
    // succeeds or fails, so the flag is derived from |type|.
    case MOCK_PENDING_CLIENT_SOCKET:
    case MOCK_PENDING_FAILING_CLIENT_SOCKET:
        return MockTriggerableClientSocket::MakeMockPendingClientSocket(
            addresses, type == MOCK_PENDING_CLIENT_SOCKET, net_log_);

    case MOCK_DELAYED_CLIENT_SOCKET:
    case MOCK_DELAYED_FAILING_CLIENT_SOCKET:
        return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
            addresses, type == MOCK_DELAYED_CLIENT_SOCKET, delay_, net_log_);

    case MOCK_STALLED_CLIENT_SOCKET:
    case MOCK_STALLED_FAILING_CLIENT_SOCKET:
        return MockTriggerableClientSocket::MakeMockStalledClientSocket(
            addresses, net_log_, type == MOCK_STALLED_FAILING_CLIENT_SOCKET);

    case MOCK_TRIGGERABLE_CLIENT_SOCKET: {
        std::unique_ptr<MockTriggerableClientSocket> triggerable(
            new MockTriggerableClientSocket(addresses, true, net_log_));
        // Queue the connect closure so that a test can fetch it through
        // WaitForTriggerableSocketCreation().
        triggerable_sockets_.push(triggerable->GetConnectCallback());
        // run_loop_quit_closure_ behaves like a condition variable. It will
        // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
        // don't need to worry about atomicity because this code is
        // single-threaded.
        if (!run_loop_quit_closure_.is_null())
            run_loop_quit_closure_.Run();
        return std::move(triggerable);
    }

    default:
        // Unknown type: flag it, then fall back to a plain connecting socket.
        NOTREACHED();
        return std::unique_ptr<StreamSocket>(
            new MockConnectClientSocket(addresses, net_log_));
    }
}

std::unique_ptr<SSLClientSocket>
MockTransportClientSocketFactory::CreateSSLClientSocket(
    std::unique_ptr<ClientSocketHandle> transport_socket,
    const HostPortPair& host_and_port,
    const SSLConfig& ssl_config,
    const SSLClientSocketContext& context)
{
    NOTIMPLEMENTED();
    return std::unique_ptr<SSLClientSocket>();
}

// Not provided by this mock factory; the call is only reported through
// NOTIMPLEMENTED().
void MockTransportClientSocketFactory::ClearSSLSessionCache()
{
    NOTIMPLEMENTED();
}

// Installs a per-call list of socket types: subsequent
// CreateTransportClientSocket() calls consume |type_list| from its first entry
// until |num_types| entries have been used. Only the pointer is stored, not a
// copy of the array. |num_types| must be positive (DCHECKed).
void MockTransportClientSocketFactory::set_client_socket_types(
    ClientSocketType* type_list,
    int num_types)
{
    DCHECK_GT(num_types, 0);
    client_socket_index_max_ = num_types;
    client_socket_index_ = 0;  // Restart consumption from the front.
    client_socket_types_ = type_list;
}

// Returns the connect closure of the oldest MOCK_TRIGGERABLE_CLIENT_SOCKET
// that has not been handed out yet, removing it from the queue. If no such
// socket has been created so far, runs a RunLoop until
// CreateTransportClientSocket() queues one and quits the loop.
base::Closure
MockTransportClientSocketFactory::WaitForTriggerableSocketCreation()
{
    for (;;) {
        if (!triggerable_sockets_.empty())
            break;

        // Publish the quit closure so that socket creation can wake this loop,
        // and withdraw it again as soon as the loop has returned.
        base::RunLoop run_loop;
        run_loop_quit_closure_ = run_loop.QuitClosure();
        run_loop.Run();
        run_loop_quit_closure_.Reset();
    }

    base::Closure oldest_trigger = triggerable_sockets_.front();
    triggerable_sockets_.pop();
    return oldest_trigger;
}

} // namespace net
